#pragma once

#include <NvInfer.h>
#include <NvInferVersion.h>
#include <NvOnnxConfig.h>
#include <NvOnnxParser.h>
#include <assert.h>
#include <cuda_runtime_api.h>

#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#define CUDA_CHECK(status)                                               \
  do {                                                                   \
    auto ret = (status);                                                 \
    if (ret != cudaSuccess) {                                            \
      std::cout << "CUDA failed with error code: " << ret                \
                << ", reason: " << cudaGetErrorString(ret) << std::endl; \
      exit(1);                                                           \
    }                                                                    \
  } while (0)

class MyLogger : public nvinfer1::ILogger {
 public:
  MyLogger(Severity severity = Severity::kWARNING)
      : reportableSeverity(severity) {}

  void log(Severity severity, const char *msg) noexcept override {
    if (severity > reportableSeverity) return;
    switch (severity) {
      case Severity::kINTERNAL_ERROR:
        std::cerr << "INTERNAL_ERROR: ";
        break;
      case Severity::kERROR:
        std::cerr << "ERROR: ";
        break;
      case Severity::kWARNING:
        std::cerr << "WARNING: ";
        break;
      case Severity::kINFO:
        std::cerr << "INFO: ";
        break;
      default:
        std::cerr << "UNKNOWN: ";
        break;
    }
    std::cerr << msg << std::endl;
  }

  Severity reportableSeverity;
};

class TensorrtOnnxInference {
 public:
  TensorrtOnnxInference(const std::string &model_path)
      : model_path_(model_path) {}

  virtual ~TensorrtOnnxInference();

  bool Init();

  bool Infer(const std::vector<float *> &input_data,
             std::vector<const float *> &output_data);

  std::pair<int, int> GetModelInputDims(const int index) const;

  std::pair<int, int> GetModelOutputDims(const int index) const;

 private:
  void doInference(const std::vector<float *> &input_data,
                   std::vector<const float *> &output_data);
  void AllocateBuffers();

  nvinfer1::ICudaEngine *SerializeToEngineFile(const std::string &model_path,
                                               const std::string &engine_path);

  nvinfer1::ICudaEngine *LoadFromEngineFile(const std::string &engine_path);

 private:
  std::string model_path_;

  cudaStream_t stream_;
  void *buffers_[4] = {nullptr};
  std::vector<float *> output_buffers_;
  std::vector<int> input_indexes_;
  std::vector<int> output_indexes_;
  std::vector<int> input_sizes_;
  std::vector<int> output_sizes_;
  std::vector<std::pair<int, int>> model_input_dims_;
  std::vector<std::pair<int, int>> model_output_dims_;
  nvinfer1::ICudaEngine *engine_;
  nvinfer1::IExecutionContext *context_;
  MyLogger gLogger_;
};
